# AUTHOR: DING
# -*- coding: utf-8 -*-
# @Time: 2024/3/1 14:01
# @Author: 86139
# @Site: 
# @File: 25-test.py
# @Software: PyCharm
# tensorboard --logdir=pytorch/logs --port=6007
import torchvision
from PIL import Image
import torch
from model import *
import cv2 as cv

# Model validation: run a single image through the trained network and print
# the raw scores and the predicted class index.
pth = "./images/003.png"
img = cv.imread(pth)
# cv.imread signals failure (missing/unreadable file) by returning None rather
# than raising, which would otherwise surface as an obscure error in cv.resize.
if img is None:
    raise FileNotFoundError(f"Could not read image: {pth}")
# OpenCV decodes images in BGR channel order. Convert to RGB so the input
# matches Image.open(pth).convert('RGB'); presumably the model was trained on
# RGB input (e.g. torchvision/PIL datasets) -- confirm against training code.
img = cv.cvtColor(img, cv.COLOR_BGR2RGB)
img = cv.resize(img, (32, 32))
print(img.shape)  # (H, W, C) = (32, 32, 3)
# ToTensor: HWC uint8 array in [0, 255] -> CHW float tensor in [0.0, 1.0].
transform = torchvision.transforms.Compose([torchvision.transforms.ToTensor()])
img = transform(img)
print(img.shape)  # (C, H, W) = (3, 32, 32)
# Add a leading batch dimension: (1, 3, 32, 32).
img = img.unsqueeze(0)

network = MyModule()
# map_location="cpu" lets a checkpoint saved during a GPU run load on a
# CPU-only machine; the input tensor above also lives on the CPU.
network.load_state_dict(torch.load("22-model.pth", map_location="cpu"))
# Put the module in inference mode so layers such as dropout/batch-norm
# (if MyModule contains any) behave deterministically.
network.eval()

with torch.no_grad():
    output = network(img)
    print(output)
# Index of the highest-scoring output along dim 1 for the single batch item.
print(output.argmax(1))
